#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   get_data_resnet.py    
@Contact :   raogx.vip@hotmail.com
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
2021/12/13 11:03 下午   caijiahao      1.0         Bytedancer
'''

# import lib
import os
import shutil


# 拆分数据
def get_address():
    """Collect the dog and cat image file names found in ``./train``.

    Only regular files are considered. ``arrange()`` creates ``dog`` and
    ``cat`` sub-directories inside ``train``; their names pass the prefix
    test as well, so without the file check a second run would report those
    directories as images. The names are sorted so the result does not
    depend on the arbitrary order returned by ``os.listdir``.

    Returns:
        tuple: ``(dog_file, cat_file, root)`` where ``dog_file`` and
        ``cat_file`` are sorted lists of file names (not full paths) that
        start with ``'dog'`` / ``'cat'``, and ``root`` is the current
        working directory.

    Raises:
        OSError: if ``./train`` cannot be listed (e.g. it does not exist).
    """
    train_dir = './train'
    data_file = sorted(
        name for name in os.listdir(train_dir)
        if os.path.isfile(os.path.join(train_dir, name))
    )

    dog_file = [name for name in data_file if name.startswith('dog')]
    cat_file = [name for name in data_file if name.startswith('cat')]

    root = os.getcwd()

    return dog_file, cat_file, root


def _split_class(root, label, names, train_ratio):
    """Move one class's images from ``train`` into ``train/<label>`` or ``val/<label>``."""
    src_dir = os.path.join(root, 'train')
    # Keep regular files only: the class folders created by arrange() share
    # the 'dog'/'cat' prefix, and moving one of them would relocate every
    # image that was already sorted into it.
    images = [name for name in names
              if os.path.isfile(os.path.join(src_dir, name))]
    cutoff = train_ratio * len(images)
    for index, name in enumerate(images):
        split = 'train' if index < cutoff else 'val'
        shutil.move(os.path.join(src_dir, name),
                    os.path.join(root, split, label))


def arrange(train_ratio=0.9):
    """Sort the images in ``train`` into per-class train/val folders.

    Creates ``train/dog``, ``train/cat``, ``val/dog`` and ``val/cat`` under
    the root returned by ``get_address()``. It then moves every dog/cat
    image out of ``train``: the leading ``train_ratio`` share of each class
    goes to ``train/<class>`` and the remainder to ``val/<class>``.

    Args:
        train_ratio: fraction of each class kept for training, between 0
            and 1. The default of 0.9 sends roughly 10% of each class to
            the validation set.

    Raises:
        ValueError: if ``train_ratio`` is outside ``[0, 1]``.
        FileExistsError: if one of the target folders exists but is not a
            directory.
        shutil.Error: if an image with the same name already exists in the
            destination folder.
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError('train_ratio must be between 0 and 1')

    dog_file, cat_file, root = get_address()

    print('开始数据整理')
    # Create the target folders; folders that already exist are left as is.
    for split in ('train', 'val'):
        for label in ('dog', 'cat'):
            os.makedirs(os.path.join(root, split, label), exist_ok=True)

    # Same split procedure for both classes.
    for label, names in (('dog', dog_file), ('cat', cat_file)):
        _split_class(root, label, names, train_ratio)
    print('数据整理完成')


# 转为可读入数据
def get_data(input_size, batch_size):
    """Build the transforms, datasets and loaders for training and validation.

    Images are read with ``ImageFolder`` from the relative folders ``train``
    and ``val``.

    Args:
        input_size: target edge length handed to ``RandomResizedCrop`` (train)
            and, as ``[input_size, input_size]``, to ``Resize`` (val).
        batch_size: batch size used by both data loaders.

    Returns:
        tuple: ``(transform_train, train_set, train_loader,
        transform_val, val_set, val_loader)``.
    """
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import ImageFolder

    def make_normalize():
        # Fresh instance per pipeline; with mean 0.5 and std 0.5 per channel
        # a [0, 1] tensor is mapped to [-1, 1].
        return transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # Training pipeline: random crop resized to input_size, random horizontal
    # flip, conversion to a tensor, then normalization.
    train_steps = [
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        make_normalize(),
    ]
    transform_train = transforms.Compose(train_steps)
    train_set = ImageFolder(root='train', transform=transform_train)
    # Training batches are reshuffled.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # Validation pipeline: no augmentation, just a fixed resize. Resize is
    # given both dimensions here, unlike RandomResizedCrop above.
    val_steps = [
        transforms.Resize([input_size, input_size]),
        transforms.ToTensor(),
        make_normalize(),
    ]
    transform_val = transforms.Compose(val_steps)
    val_set = ImageFolder(root='val', transform=transform_val)
    # Validation batches keep the dataset order.
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    return (transform_train, train_set, train_loader,
            transform_val, val_set, val_loader)
